{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lzvsh1Qf5Qrq"
   },
   "source": [
    "# Waymo Open Dataset 3D Camera-Only Detection Tutorial\n",
    "\n",
    "- Website: https://waymo.com/open\n",
    "- GitHub: https://github.com/waymo-research/waymo-open-dataset\n",
    "- Challenge: https://waymo.com/intl/en_us/open/challenges/2022/3d-camera-only-detection/\n",
    "\n",
    "This tutorial demonstrates how to interpret the camera-synced labels. Visit the [Waymo Open Dataset Website](https://waymo.com/open) to download the full dataset.\n",
    "\n",
    "To use, open this notebook in [Colab](https://colab.research.google.com).\n",
    "\n",
    "Uncheck the box \"Reset all runtimes before running\" if you run this colab directly from the remote kernel. Alternatively, you can make a copy before trying to run it by following \"File > Save copy in Drive ...\".\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FBkRKAyw5Ty7"
   },
   "source": [
    "# Package Installation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "phBXqdKNlRNR"
   },
   "source": [
    "To run the colab against dataset files stored on the local machine, we recommend launching a local runtime in your python environment with `waymo-open-dataset` installed. Please follow the instructions [here](https://research.google.com/colaboratory/local-runtimes.html).\n",
    "\n",
    "Otherwise, you can follow the instructions in [tutorial.ipynb](https://github.com/waymo-research/waymo-open-dataset/blob/master/tutorial/tutorial.ipynb). Please pip install `waymo-open-dataset-tf-2-12-0==1.6.7` later than `1.4.6` and upload required segments to colab with any method in [io.ipynb](https://colab.research.google.com/notebooks/io.ipynb). Note that the upload could take a while due to large file sizes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Jbivf1ML5Voi"
   },
   "source": [
    "# Imports and Global Definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jTqJGjoX5Q57"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import patches\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "if not tf.executing_eagerly():\n",
    "  tf.compat.v1.enable_eager_execution()\n",
    "\n",
    "from waymo_open_dataset import dataset_pb2 as open_dataset\n",
    "from waymo_open_dataset import label_pb2\n",
    "from waymo_open_dataset.camera.ops import py_camera_model_ops\n",
    "from waymo_open_dataset.metrics.ops import py_metrics_ops\n",
    "from waymo_open_dataset.metrics.python import config_util_py as config_util\n",
    "from waymo_open_dataset.protos import breakdown_pb2\n",
    "from waymo_open_dataset.protos import metrics_pb2\n",
    "from waymo_open_dataset.protos import submission_pb2\n",
    "from waymo_open_dataset.utils import box_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Sd-C6LAO5MP8"
   },
   "outputs": [],
   "source": [
    "# Data location - please edit. Should point to a tfrecord containing tf.Example\n",
    "# protos as downloaded from the Waymo Open Dataset website.\n",
    "\n",
    "FILENAME = '/content/waymo-od/tutorial/.../training/segment-1305342127382455702_3720_000_3740_000_with_camera_labels.tfrecord'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9QHlEkkp5vU1"
   },
   "source": [
    "# Read Frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2fkmVNp6hAVV"
   },
   "outputs": [],
   "source": [
    "dataset = tf.data.TFRecordDataset(FILENAME, compression_type='')\n",
    "dataset_iter = dataset.as_numpy_iterator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YXhhbCBvhMCW"
   },
   "outputs": [],
   "source": [
    "# Get a frame from the segment. Run this cell again to obtain next frame.\n",
    "data = next(dataset_iter)\n",
    "frame = open_dataset.Frame()\n",
    "frame.ParseFromString(data)\n",
    "print(frame.context.name, frame.timestamp_micros)\n",
    "\n",
    "FILTER_AVAILABLE = any(\n",
    "    [label.num_top_lidar_points_in_box > 0 for label in frame.laser_labels])\n",
    "\n",
    "if not FILTER_AVAILABLE:\n",
    "  print('WARNING: num_top_lidar_points_in_box does not seem to be populated. '\n",
    "        'Make sure that you are using an up-to-date release (V1.3.2 or later) '\n",
    "        'to enable improved filtering of occluded objects.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fwfTFg8G9y6Q"
   },
   "source": [
    "# Visualize Camera Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3qGWRwOQZIai"
   },
   "outputs": [],
   "source": [
    "def show_camera_image(camera_image, layout):\n",
    "  \"\"\"Display the given camera image.\"\"\"\n",
    "  ax = plt.subplot(*layout)\n",
    "  plt.imshow(tf.image.decode_jpeg(camera_image.image))\n",
    "  plt.title(open_dataset.CameraName.Name.Name(camera_image.name))\n",
    "  plt.grid(False)\n",
    "  plt.axis('off')\n",
    "  return ax\n",
    "\n",
    "\n",
    "plt.figure(figsize=(25, 20))\n",
    "\n",
    "for index, image in enumerate(frame.images):\n",
    "  _ = show_camera_image(image, [3, 3, index + 1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FKIOHn6m568Y"
   },
   "source": [
    "# Inspect 3D Labels\n",
    "\n",
    "In addition to the original LiDAR-synced boxes (`box`), there are now camera-synced boxes (`camera_synced_box`) available, which can be useful for camera-centric modeling. Camera-synced boxes are designed to more accurately reflect the state of objects when they were observed by the camera, which is different from the time they were observed by our LiDAR sensors. During the time period between the sensors' respective captures of different parts of the surrounding environment, both the autonomously driven vehicle (ADV) and dynamic objects move, which is reflected by these boxes having slightly different positions.\n",
    "\n",
    "Note that camera-synced boxes are only populated for boxes that are visible in at least one camera. Every camera-synced box also has a LiDAR-synced box equivalent, but not vice versa. This is the case because the cameras collectively only span a horizontal field of view of ~230 degrees."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jCaqx7Wy51np"
   },
   "outputs": [],
   "source": [
    "lidar_synced_boxes = [lbl.box for lbl in frame.laser_labels]\n",
    "camera_synced_boxes = [\n",
    "    lbl.camera_synced_box\n",
    "    for lbl in frame.laser_labels\n",
    "    if lbl.camera_synced_box.ByteSize()\n",
    "]\n",
    "print(\n",
    "    'Frame contains %s LiDAR-synced boxes, of which %s also have a camera-synced box equivalent.'\n",
    "    % (len(lidar_synced_boxes), len(camera_synced_boxes)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1jcW9eN7zx5a"
   },
   "source": [
    "Along with their camera-synced box, we also provide `most_visible_camera`, the name of the camera where the object is most visible in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "O52R4jJC6j2l"
   },
   "outputs": [],
   "source": [
    "for lbl in frame.laser_labels:\n",
    "  if not lbl.camera_synced_box.ByteSize():\n",
    "    continue\n",
    "\n",
    "  offset = np.linalg.norm((lbl.box.center_x - lbl.camera_synced_box.center_x,\n",
    "                           lbl.box.center_y - lbl.camera_synced_box.center_y,\n",
    "                           lbl.box.center_z - lbl.camera_synced_box.center_z))\n",
    "  print(\n",
    "      'Label %s has an offset of %sm between box and camera_synced_box, and is most visible in camera %s.'\n",
    "      % (lbl.id, round(offset, 3), lbl.most_visible_camera_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LYOAtTfMZgs9"
   },
   "source": [
    "## Visualize projected_lidar_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5Ns9Vw3NbKUH"
   },
   "source": [
    "Here, we visualize `projected_lidar_labels`, which contains pre-projected 2D bounding boxes of all laser labels into all available camera images.\n",
    "\n",
    "Note that unlike `camera_labels`, which are only associated with the corresponding `laser_labels` for certain object types, there is always a correspondence between `projected_lidar_labels` and `laser_labels`."
   ]
  },
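  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the box convention used by the drawing code in this section: each projected label stores its 2D box as a center plus extents, with `length` measured along the image x-axis and `width` along the image y-axis. The helper name `box_to_corners` and the numbers below are hypothetical and only illustrate the arithmetic.\n",
    "\n",
    "```python\n",
    "def box_to_corners(center_x, center_y, length, width):\n",
    "  \"\"\"Returns (x_min, y_min, x_max, y_max) for a center/extent 2D box.\"\"\"\n",
    "  half_length = 0.5 * length\n",
    "  half_width = 0.5 * width\n",
    "  return (center_x - half_length, center_y - half_width,\n",
    "          center_x + half_length, center_y + half_width)\n",
    "\n",
    "\n",
    "# Hypothetical box: centered at pixel (100, 50), 40 px wide and 20 px tall.\n",
    "corners = box_to_corners(100.0, 50.0, 40.0, 20.0)\n",
    "print(corners)\n",
    "```\n",
    "\n",
    "The top-left corner `(x_min, y_min)` together with `length` and `width` is what `patches.Rectangle` expects as `xy`, `width` and `height`."
   ]
  },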
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JodhPWuoZJum"
   },
   "outputs": [],
   "source": [
    "def show_projected_lidar_labels(camera_image, ax):\n",
    "  \"\"\"Displays pre-projected 3D laser labels.\"\"\"\n",
    "\n",
    "  for projected_labels in frame.projected_lidar_labels:\n",
    "    # Ignore camera labels that do not correspond to this camera.\n",
    "    if projected_labels.name != camera_image.name:\n",
    "      continue\n",
    "\n",
    "    # Iterate over the individual labels.\n",
    "    for label in projected_labels.labels:\n",
    "      # Draw the bounding box.\n",
    "      rect = patches.Rectangle(\n",
    "          xy=(label.box.center_x - 0.5 * label.box.length,\n",
    "              label.box.center_y - 0.5 * label.box.width),\n",
    "          width=label.box.length,\n",
    "          height=label.box.width,\n",
    "          linewidth=1,\n",
    "          edgecolor=(0.0, 1.0, 0.0, 1.0),  # green\n",
    "          facecolor=(0.0, 1.0, 0.0, 0.1))  # opaque green\n",
    "      ax.add_patch(rect)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FZ-PoLgFdCRp"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(25, 20))\n",
    "\n",
    "for index, image in enumerate(frame.images):\n",
    "  ax = show_camera_image(image, [3, 3, index + 1])\n",
    "  show_projected_lidar_labels(image, ax)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Zaqk7X6aAsQw"
   },
   "source": [
    "## Visualize camera_synced_box Projections"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0iSY2szjCvqe"
   },
   "source": [
    "Here, we visualize projections of `camera_synced_box` onto the camera images. While the cameras have a rolling shutter, we use a global shutter projection, since rolling shutter effects were already factored in when creating the `camera_synced_box` labels."
   ]
  },
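  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, a global shutter projection treats every pixel as captured at the same instant, so a point can be projected with a single camera pose. The sketch below is a simplified, distortion-free pinhole model, not the implementation behind `py_camera_model_ops.world_to_image`. It assumes a camera frame with x pointing forward, y to the left and z up; the function name `project_pinhole` and all numeric values are hypothetical.\n",
    "\n",
    "```python\n",
    "def project_pinhole(point, f_u, f_v, c_u, c_v):\n",
    "  \"\"\"Projects a camera-frame point (x forward, y left, z up) to pixels.\n",
    "\n",
    "  Returns (u, v, ok), where ok is False for points behind the camera.\n",
    "  \"\"\"\n",
    "  x, y, z = point\n",
    "  if x <= 0.0:\n",
    "    return 0.0, 0.0, False\n",
    "  # Assumed convention: u grows to the right and v grows downward.\n",
    "  u = c_u - f_u * y / x\n",
    "  v = c_v - f_v * z / x\n",
    "  return u, v, True\n",
    "\n",
    "\n",
    "# Hypothetical intrinsics and a point 10 m ahead, 1 m to the left, 0.5 m up.\n",
    "print(project_pinhole((10.0, 1.0, 0.5), 2000.0, 2000.0, 960.0, 640.0))\n",
    "```\n",
    "\n",
    "A rolling shutter model would additionally need per-row capture times and the vehicle motion during readout, which is why the projection helper in this section can leave the velocity and latency fields at zero when `GLOBAL_SHUTTER` is selected."
   ]
  },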
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0tH0chO7ZGkP"
   },
   "outputs": [],
   "source": [
    "def draw_2d_box(ax, u, v, color, linewidth=1):\n",
    "  \"\"\"Draws 2D bounding boxes as rectangles onto the given axis.\"\"\"\n",
    "  rect = patches.Rectangle(\n",
    "      xy=(u.min(), v.min()),\n",
    "      width=u.max() - u.min(),\n",
    "      height=v.max() - v.min(),\n",
    "      linewidth=linewidth,\n",
    "      edgecolor=color,\n",
    "      facecolor=list(color) + [0.1])  # Add alpha for opacity\n",
    "  ax.add_patch(rect)\n",
    "\n",
    "\n",
    "def draw_3d_wireframe_box(ax, u, v, color, linewidth=3):\n",
    "  \"\"\"Draws 3D wireframe bounding boxes onto the given axis.\"\"\"\n",
    "  # List of lines to interconnect. Allows for various forms of connectivity.\n",
    "  # Four lines each describe bottom face, top face and vertical connectors.\n",
    "  lines = ((0, 1), (1, 2), (2, 3), (3, 0), (4, 5), (5, 6), (6, 7), (7, 4),\n",
    "           (0, 4), (1, 5), (2, 6), (3, 7))\n",
    "\n",
    "  for (point_idx1, point_idx2) in lines:\n",
    "    line = plt.Line2D(\n",
    "        xdata=(int(u[point_idx1]), int(u[point_idx2])),\n",
    "        ydata=(int(v[point_idx1]), int(v[point_idx2])),\n",
    "        linewidth=linewidth,\n",
    "        color=list(color) + [0.5])  # Add alpha for opacity\n",
    "    ax.add_line(line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5TkO8YUcZNl6"
   },
   "outputs": [],
   "source": [
    "def project_vehicle_to_image(vehicle_pose, calibration, points):\n",
    "  \"\"\"Projects from vehicle coordinate system to image with global shutter.\n",
    "\n",
    "  Arguments:\n",
    "    vehicle_pose: Vehicle pose transform from vehicle into world coordinate\n",
    "      system.\n",
    "    calibration: Camera calibration details (including intrinsics/extrinsics).\n",
    "    points: Points to project of shape [N, 3] in vehicle coordinate system.\n",
    "\n",
    "  Returns:\n",
    "    Array of shape [N, 3], with the latter dimension composed of (u, v, ok).\n",
    "  \"\"\"\n",
    "  # Transform points from vehicle to world coordinate system (can be\n",
    "  # vectorized).\n",
    "  pose_matrix = np.array(vehicle_pose.transform).reshape(4, 4)\n",
    "  world_points = np.zeros_like(points)\n",
    "  for i, point in enumerate(points):\n",
    "    cx, cy, cz, _ = np.matmul(pose_matrix, [*point, 1])\n",
    "    world_points[i] = (cx, cy, cz)\n",
    "\n",
    "  # Populate camera image metadata. Velocity and latency stats are filled with\n",
    "  # zeroes.\n",
    "  extrinsic = tf.reshape(\n",
    "      tf.constant(list(calibration.extrinsic.transform), dtype=tf.float32),\n",
    "      [4, 4])\n",
    "  intrinsic = tf.constant(list(calibration.intrinsic), dtype=tf.float32)\n",
    "  metadata = tf.constant([\n",
    "      calibration.width,\n",
    "      calibration.height,\n",
    "      open_dataset.CameraCalibration.GLOBAL_SHUTTER,\n",
    "  ],\n",
    "                         dtype=tf.int32)\n",
    "  camera_image_metadata = list(vehicle_pose.transform) + [0.0] * 10\n",
    "\n",
    "  # Perform projection and return projected image coordinates (u, v, ok).\n",
    "  return py_camera_model_ops.world_to_image(extrinsic, intrinsic, metadata,\n",
    "                                            camera_image_metadata,\n",
    "                                            world_points).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NAX3z67BZEQ0"
   },
   "outputs": [],
   "source": [
    "def show_projected_camera_synced_boxes(camera_image, ax, draw_3d_box=False):\n",
    "  \"\"\"Displays camera_synced_box 3D labels projected onto camera.\"\"\"\n",
    "  # Fetch matching camera calibration.\n",
    "  calibration = next(cc for cc in frame.context.camera_calibrations\n",
    "                     if cc.name == camera_image.name)\n",
    "\n",
    "  for label in frame.laser_labels:\n",
    "    box = label.camera_synced_box\n",
    "\n",
    "    if not box.ByteSize():\n",
    "      continue  # Filter out labels that do not have a camera_synced_box.\n",
    "    if (FILTER_AVAILABLE and not label.num_top_lidar_points_in_box) or (\n",
    "        not FILTER_AVAILABLE and not label.num_lidar_points_in_box):\n",
    "      continue  # Filter out likely occluded objects.\n",
    "\n",
    "    # Retrieve upright 3D box corners.\n",
    "    box_coords = np.array([[\n",
    "        box.center_x, box.center_y, box.center_z, box.length, box.width,\n",
    "        box.height, box.heading\n",
    "    ]])\n",
    "    corners = box_utils.get_upright_3d_box_corners(\n",
    "        box_coords)[0].numpy()  # [8, 3]\n",
    "\n",
    "    # Project box corners from vehicle coordinates onto the image.\n",
    "    projected_corners = project_vehicle_to_image(frame.pose, calibration,\n",
    "                                                 corners)\n",
    "    u, v, ok = projected_corners.transpose()\n",
    "    ok = ok.astype(bool)\n",
    "\n",
    "    # Skip object if any corner projection failed. Note that this is very\n",
    "    # strict and can lead to exclusion of some partially visible objects.\n",
    "    if not all(ok):\n",
    "      continue\n",
    "    u = u[ok]\n",
    "    v = v[ok]\n",
    "\n",
    "    # Clip box to image bounds.\n",
    "    u = np.clip(u, 0, calibration.width)\n",
    "    v = np.clip(v, 0, calibration.height)\n",
    "\n",
    "    if u.max() - u.min() == 0 or v.max() - v.min() == 0:\n",
    "      continue\n",
    "\n",
    "    if draw_3d_box:\n",
    "      # Draw approximate 3D wireframe box onto the image. Occlusions are not\n",
    "      # handled properly.\n",
    "      draw_3d_wireframe_box(ax, u, v, (1.0, 1.0, 0.0))\n",
    "    else:\n",
    "      # Draw projected 2D box onto the image.\n",
    "      draw_2d_box(ax, u, v, (1.0, 1.0, 0.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Sx2Gn386ZXs5"
   },
   "source": [
    "### 3D Wireframe Boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cI2uY5j0E09s"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(25, 20))\n",
    "\n",
    "for index, image in enumerate(frame.images):\n",
    "  ax = show_camera_image(image, [3, 3, index + 1])\n",
    "  show_projected_camera_synced_boxes(image, ax, draw_3d_box=True)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gm0lcOtmZZfm"
   },
   "source": [
    "### 2D Boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KPJHQYweZayf"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(25, 20))\n",
    "\n",
    "for index, image in enumerate(frame.images):\n",
    "  ax = show_camera_image(image, [3, 3, index + 1])\n",
    "  show_projected_camera_synced_boxes(image, ax, draw_3d_box=False)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2Me302RHe0ox"
   },
   "source": [
    "## Comparison"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xhy3E1bmisJt"
   },
   "source": [
    "This visualization compares `projected_lidar_labels` (green) with our global-shutter projection of `camera_synced_box` (yellow). The results are similar, but some discrepancies can be observed. This is due to the fact that `projected_lidar_labels` are based on rolling shutter-projected LiDAR-synced `box` labels.\n",
    "\n",
    "We recommend using the already provided `projected_lidar_labels` if corresponding 2D boxes are needed.\n",
    "\n",
    "For more details, please also refer to the [challenge webpage](https://waymo.com/intl/en_us/open/challenges/2022/3d-camera-only-detection/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vDerTC_ZeLjW"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(25, 20))\n",
    "\n",
    "for index, image in enumerate(frame.images):\n",
    "  ax = show_camera_image(image, [3, 3, index + 1])\n",
    "  show_projected_lidar_labels(image, ax)\n",
    "  show_projected_camera_synced_boxes(image, ax, draw_3d_box=False)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jIuIEITJWQ7A"
   },
   "source": [
    "# Create Submission"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YR7wzBD1S_Th"
   },
   "source": [
    "To participate in the challenge, a submission file containing predictions for the validation or test set must be assembled and submitted to the [challenge website](https://waymo.com/intl/en_us/open/challenges/2022/3d-camera-only-detection/). The evaluation server will compute detailed metrics and add them to the leaderboard.\n",
    "\n",
    "The cells below exemplify how to create a submission file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4ZK81lX2ZPjf"
   },
   "outputs": [],
   "source": [
    "def make_inference_objects(context_name, timestamp, boxes, classes, scores):\n",
    "  \"\"\"Create objects based on inference results of a frame.\n",
    "\n",
    "  Args:\n",
    "    context_name: The context name of the segment.\n",
    "    timestamp: The timestamp of the frame.\n",
    "    boxes: A [N, 7] float numpy array that describe the inferences boxes of the\n",
    "      frame, assuming each row is of the form [center_x, center_y, center_z,\n",
    "      length, width, height, heading].\n",
    "    classes: A [N] numpy array that describe the inferences classes. See\n",
    "      label_pb2.Label.Type for the class values. TYPE_VEHICLE = 1;\n",
    "      TYPE_PEDESTRIAN = 2; TYPE_SIGN = 3; TYPE_CYCLIST = 4;\n",
    "    scores: A [N] float numpy array that describe the detection scores.\n",
    "\n",
    "  Returns:\n",
    "    A list of metrics_pb2.Object.\n",
    "  \"\"\"\n",
    "  objects = []\n",
    "  for i in range(boxes.shape[0]):\n",
    "    x, y, z, l, w, h, heading = boxes[i]\n",
    "    cls = classes[i]\n",
    "    score = scores[i]\n",
    "    objects.append(\n",
    "        metrics_pb2.Object(\n",
    "            object=label_pb2.Label(\n",
    "                box=label_pb2.Label.Box(\n",
    "                    center_x=x,\n",
    "                    center_y=y,\n",
    "                    center_z=z,\n",
    "                    length=l,\n",
    "                    width=w,\n",
    "                    height=h,\n",
    "                    heading=heading),\n",
    "                type=label_pb2.Label.Type.Name(cls),\n",
    "                id=f'{cls}_{i}'),\n",
    "            score=score,\n",
    "            context_name=context_name,\n",
    "            frame_timestamp_micros=timestamp))\n",
    "  return objects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Wn_Q2KuoWTdk"
   },
   "outputs": [],
   "source": [
    "# Prepare predictions. Please modify accordingly to process your inference results.\n",
    "context_names = ['1305342127382455702_3720_000_3740_000']\n",
    "\n",
    "frame_timestamps = {\n",
    "    # Please make sure that the timestamps match frame.timestamp_micros.\n",
    "    '1305342127382455702_3720_000_3740_000': [1511019682029265, 1511019682129243]\n",
    "}\n",
    "\n",
    "prediction_objects = {}\n",
    "for context_name in context_names:\n",
    "  prediction_objects[context_name] = {}\n",
    "  for timestamp in frame_timestamps[context_name]:\n",
    "    # Create objects based on inference results\n",
    "    prediction_objects[context_name][timestamp] = make_inference_objects(\n",
    "        context_name=context_name,\n",
    "        timestamp=timestamp,\n",
    "        boxes=np.random.rand(3, 7),\n",
    "        classes=np.random.randint(low=1, high=4, size=(3,)),\n",
    "        scores=np.random.rand(3,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "O39YXkbpb6VK"
   },
   "outputs": [],
   "source": [
    "context_name = context_names[0]\n",
    "timestamp = frame_timestamps[context_name][0]\n",
    "print(prediction_objects[context_name][timestamp][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3L4A8PfnZTCf"
   },
   "outputs": [],
   "source": [
    "# Pack to submission.\n",
    "num_submission_shards = 4  # Please modify accordingly.\n",
    "submission_file_base = '/tmp/MySubmission'  # Please modify accordingly.\n",
    "\n",
    "if not os.path.exists(submission_file_base):\n",
    "  os.makedirs(submission_file_base)\n",
    "sub_file_names = [\n",
    "    os.path.join(submission_file_base, part)\n",
    "    for part in [f'part{i}' for i in range(num_submission_shards)]\n",
    "]\n",
    "\n",
    "submissions = [\n",
    "    submission_pb2.Submission(inference_results=metrics_pb2.Objects())\n",
    "    for i in range(num_submission_shards)\n",
    "]\n",
    "\n",
    "obj_counter = 0\n",
    "for c_name, frames in prediction_objects.items():\n",
    "  for timestamp, objects in frames.items():\n",
    "    for obj in objects:\n",
    "      submissions[obj_counter %\n",
    "                  num_submission_shards].inference_results.objects.append(obj)\n",
    "      obj_counter += 1\n",
    "\n",
    "for i, shard in enumerate(submissions):\n",
    "  shard.task = submission_pb2.Submission.CAMERA_ONLY_DETECTION_3D\n",
    "  shard.authors[:] = ['A', 'B']  # Please modify accordingly.\n",
    "  shard.affiliation = 'Affiliation'  # Please modify accordingly.\n",
    "  shard.account_name = 'acc@domain.com'  # Please modify accordingly.\n",
    "  shard.unique_method_name = 'YourMethodName'  # Please modify accordingly.\n",
    "  shard.method_link = 'method_link'  # Please modify accordingly.\n",
    "  shard.description = ''  # Please modify accordingly.\n",
    "  shard.sensor_type = submission_pb2.Submission.CAMERA_ALL\n",
    "  shard.number_past_frames_exclude_current = 0  # Please modify accordingly.\n",
    "  shard.object_types[:] = [\n",
    "      label_pb2.Label.TYPE_VEHICLE, label_pb2.Label.TYPE_PEDESTRIAN,\n",
    "      label_pb2.Label.TYPE_CYCLIST\n",
    "  ]\n",
    "  with tf.io.gfile.GFile(sub_file_names[i], 'wb') as fp:\n",
    "    fp.write(shard.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ylh_PKzwe3K3"
   },
   "outputs": [],
   "source": [
    "print(submissions[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t6K4gUybfsG9"
   },
   "source": [
    "## Package submission\n",
    "```\n",
    "cd /tmp\n",
    "tar cvf MySubmission.tar MySubmission\n",
    "gzip MySubmission.tar\n",
    "```\n",
    "Then you can upload `/tmp/MySubmission.tar.gz` to the challenge website.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YjHbGx-AJkcT"
   },
   "source": [
    "# Compute Metrics\n",
    "\n",
    "We provide `metrics/tools/compute_detection_let_metrics_main` as a binary tool.\n",
    "Here, we provide a python example code for calculating LET metrics."
   ]
  },
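  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The config built in the next cell sets `longitudinal_tolerance_percentage=0.1` and `min_longitudinal_tolerance_meter=0.5`. As a rough sketch of how these two values might combine, the snippet below assumes the tolerance is the larger of a range-proportional term and the minimum. This is an assumption for illustration, not the evaluation code, and the helper name `longitudinal_tolerance` is hypothetical.\n",
    "\n",
    "```python\n",
    "def longitudinal_tolerance(object_range_m, percentage=0.1, min_tolerance_m=0.5):\n",
    "  \"\"\"Assumed tolerance: max of a range-proportional term and a minimum.\"\"\"\n",
    "  return max(percentage * object_range_m, min_tolerance_m)\n",
    "\n",
    "\n",
    "# Hypothetical object ranges in meters.\n",
    "for object_range_m in (2.0, 5.0, 50.0):\n",
    "  print(object_range_m, longitudinal_tolerance(object_range_m))\n",
    "```\n",
    "\n",
    "Under this assumption, nearby objects are held to the fixed minimum, while the allowed longitudinal error grows with distance for far-away objects."
   ]
  },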
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oN4I20xrJmG3"
   },
   "outputs": [],
   "source": [
    "def build_let_metrics_config():\n",
    "  let_metric_config = metrics_pb2.Config.LongitudinalErrorTolerantConfig(\n",
    "      enabled=True,\n",
    "      sensor_location=metrics_pb2.Config.LongitudinalErrorTolerantConfig\n",
    "      .Location3D(x=1.43, y=0, z=2.18),\n",
    "      longitudinal_tolerance_percentage=0.1,  # 10% tolerance.\n",
    "      min_longitudinal_tolerance_meter=0.5,\n",
    "  )\n",
    "  config = metrics_pb2.Config(\n",
    "      box_type=label_pb2.Label.Box.TYPE_3D,\n",
    "      matcher_type=metrics_pb2.MatcherProto.TYPE_HUNGARIAN,\n",
    "      iou_thresholds=[0.0, 0.3, 0.5, 0.5, 0.5],\n",
    "      score_cutoffs=[i * 0.01 for i in range(100)] + [1.0],\n",
    "      let_metric_config=let_metric_config)\n",
    "\n",
    "  config.breakdown_generator_ids.append(breakdown_pb2.Breakdown.OBJECT_TYPE)\n",
    "  config.difficulties.append(\n",
    "      metrics_pb2.Difficulty(levels=[label_pb2.Label.LEVEL_2]))\n",
    "  config.breakdown_generator_ids.append(breakdown_pb2.Breakdown.CAMERA)\n",
    "  config.difficulties.append(\n",
    "      metrics_pb2.Difficulty(levels=[label_pb2.Label.LEVEL_2]))\n",
    "  config.breakdown_generator_ids.append(breakdown_pb2.Breakdown.RANGE)\n",
    "  config.difficulties.append(\n",
    "      metrics_pb2.Difficulty(levels=[label_pb2.Label.LEVEL_2]))\n",
    "  return config\n",
    "\n",
    "\n",
    "def compute_let_detection_metrics(prediction_frame_id,\n",
    "                                  prediction_bbox,\n",
    "                                  prediction_type,\n",
    "                                  prediction_score,\n",
    "                                  ground_truth_frame_id,\n",
    "                                  ground_truth_bbox,\n",
    "                                  ground_truth_type,\n",
    "                                  ground_truth_difficulty,\n",
    "                                  recall_at_precision=None,\n",
    "                                  name_filter=None,\n",
    "                                  config=build_let_metrics_config()):\n",
    "  \"\"\"Returns dict of metric name to metric values`.\n",
    "\n",
    "  Notation:\n",
    "    * M: number of predicted boxes.\n",
    "    * D: number of box dimensions. The number of box dimensions can be one of\n",
    "         the following:\n",
    "           4: Used for boxes with type TYPE_AA_2D (center_x, center_y, length,\n",
    "              width)\n",
    "           5: Used for boxes with type TYPE_2D (center_x, center_y, length,\n",
    "              width, heading).\n",
    "           7: Used for boxes with type TYPE_3D (center_x, center_y, center_z,\n",
    "              length, width, height, heading).\n",
    "    * N: number of ground truth boxes.\n",
    "\n",
    "  Args:\n",
    "    prediction_frame_id: [M] int64 tensor that identifies frame for each\n",
    "      prediction.\n",
    "    prediction_bbox: [M, D] tensor encoding the predicted bounding boxes.\n",
    "    prediction_type: [M] tensor encoding the object type of each prediction.\n",
    "    prediction_score: [M] tensor encoding the score of each prediciton.\n",
    "    ground_truth_frame_id: [N] int64 tensor that identifies frame for each\n",
    "      ground truth.\n",
    "    ground_truth_bbox: [N, D] tensor encoding the ground truth bounding boxes.\n",
    "    ground_truth_type: [N] tensor encoding the object type of each ground truth.\n",
    "    ground_truth_difficulty: [N] tensor encoding the difficulty level of each\n",
    "      ground truth.\n",
    "    config: The metrics config defined in protos/metrics.proto.\n",
    "\n",
    "  Returns:\n",
    "    A dictionary of metric names to metrics values.\n",
    "  \"\"\"\n",
    "  num_ground_truths = tf.shape(ground_truth_bbox)[0]\n",
    "  num_predictions = tf.shape(prediction_bbox)[0]\n",
    "  ground_truth_speed = tf.zeros((num_ground_truths, 2), tf.float32)\n",
    "  prediction_overlap_nlz = tf.zeros((num_predictions), tf.bool)\n",
    "\n",
    "  config_str = config.SerializeToString()\n",
    "  ap, aph, apl, pr, _, _, _ = py_metrics_ops.detection_metrics(\n",
    "      prediction_frame_id=tf.cast(prediction_frame_id, tf.int64),\n",
    "      prediction_bbox=tf.cast(prediction_bbox, tf.float32),\n",
    "      prediction_type=tf.cast(prediction_type, tf.uint8),\n",
    "      prediction_score=tf.cast(prediction_score, tf.float32),\n",
    "      prediction_overlap_nlz=prediction_overlap_nlz,\n",
    "      ground_truth_frame_id=tf.cast(ground_truth_frame_id, tf.int64),\n",
    "      ground_truth_bbox=tf.cast(ground_truth_bbox, tf.float32),\n",
    "      ground_truth_type=tf.cast(ground_truth_type, tf.uint8),\n",
    "      ground_truth_difficulty=tf.cast(ground_truth_difficulty, tf.uint8),\n",
    "      ground_truth_speed=ground_truth_speed,\n",
    "      config=config_str)\n",
    "  breakdown_names = config_util.get_breakdown_names_from_config(config)\n",
    "  metric_values = {}\n",
    "  for i, name in enumerate(breakdown_names):\n",
    "    if name_filter is not None and name_filter not in name:\n",
    "      continue\n",
    "    metric_values['{}/LET-mAP'.format(name)] = ap[i]\n",
    "    metric_values['{}/LET-mAPH'.format(name)] = aph[i]\n",
    "    metric_values['{}/LET-mAPL'.format(name)] = apl[i]\n",
    "  return metric_values\n",
    "\n",
    "\n",
    "def parse_metrics_objects_binary_files(ground_truths_path, predictions_path):\n",
    "  with tf.io.gfile.GFile(ground_truths_path, 'rb') as f:\n",
    "    ground_truth_objects = metrics_pb2.Objects.FromString(f.read())\n",
    "  with tf.io.gfile.GFile(predictions_path, 'rb') as f:\n",
    "    predictions_objects = metrics_pb2.Objects.FromString(f.read())\n",
    "  eval_dict = {\n",
    "      'prediction_frame_id': [],\n",
    "      'prediction_bbox': [],\n",
    "      'prediction_type': [],\n",
    "      'prediction_score': [],\n",
    "      'ground_truth_frame_id': [],\n",
    "      'ground_truth_bbox': [],\n",
    "      'ground_truth_type': [],\n",
    "      'ground_truth_difficulty': [],\n",
    "  }\n",
    "\n",
    "  # Parse and filter ground truths.\n",
    "  for obj in ground_truth_objects.objects:\n",
    "    # Ignore objects that are not in Cameras' FOV.\n",
    "    if not obj.object.most_visible_camera_name:\n",
    "      continue\n",
    "    # Ignore objects that are fully-occluded to cameras.\n",
    "    if obj.object.num_lidar_points_in_box == 0:\n",
    "      continue\n",
    "    # Fill in unknown difficulties.\n",
    "    if obj.object.detection_difficulty_level == label_pb2.Label.UNKNOWN:\n",
    "      obj.object.detection_difficulty_level = label_pb2.Label.LEVEL_2\n",
    "    eval_dict['ground_truth_frame_id'].append(obj.frame_timestamp_micros)\n",
    "    # Note that we use `camera_synced_box` for evaluation.\n",
    "    ground_truth_box = obj.object.camera_synced_box\n",
    "    eval_dict['ground_truth_bbox'].append(\n",
    "        np.asarray([\n",
    "            ground_truth_box.center_x,\n",
    "            ground_truth_box.center_y,\n",
    "            ground_truth_box.center_z,\n",
    "            ground_truth_box.length,\n",
    "            ground_truth_box.width,\n",
    "            ground_truth_box.height,\n",
    "            ground_truth_box.heading,\n",
    "        ], np.float32))\n",
    "    eval_dict['ground_truth_type'].append(obj.object.type)\n",
    "    eval_dict['ground_truth_difficulty'].append(\n",
    "        np.uint8(obj.object.detection_difficulty_level))\n",
    "\n",
    "  # Parse predictions.\n",
    "  for obj in predictions_objects.objects:\n",
    "    eval_dict['prediction_frame_id'].append(obj.frame_timestamp_micros)\n",
    "    prediction_box = obj.object.box\n",
    "    eval_dict['prediction_bbox'].append(\n",
    "        np.asarray([\n",
    "            prediction_box.center_x,\n",
    "            prediction_box.center_y,\n",
    "            prediction_box.center_z,\n",
    "            prediction_box.length,\n",
    "            prediction_box.width,\n",
    "            prediction_box.height,\n",
    "            prediction_box.heading,\n",
    "        ], np.float32))\n",
    "    eval_dict['prediction_type'].append(obj.object.type)\n",
    "    eval_dict['prediction_score'].append(obj.score)\n",
    "\n",
    "  for key, value in eval_dict.items():\n",
    "    eval_dict[key] = tf.stack(value)\n",
    "  return eval_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cDFgnc8V_FXk"
   },
   "outputs": [],
   "source": [
    "WAYMO_OPEN_DATASET_DIR = '/content/waymo_open_dataset'\n",
    "FAKE_GROUND_TRUTHS_BIN = (\n",
    "    WAYMO_OPEN_DATASET_DIR + '/metrics/tools/fake_ground_truths.bin')\n",
    "FAKE_PREDICTIONS_BIN = (\n",
    "    WAYMO_OPEN_DATASET_DIR + '/metrics/tools/fake_predictions.bin')\n",
    "\n",
    "eval_dict = parse_metrics_objects_binary_files(FAKE_GROUND_TRUTHS_BIN,\n",
    "                                               FAKE_PREDICTIONS_BIN)\n",
    "metrics_dict = compute_let_detection_metrics(**eval_dict)\n",
    "for key, value in metrics_dict.items():\n",
    "  if 'SIGN' in key:\n",
    "    continue\n",
    "  print(f'{key:<55}: {value}')"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "public_tutorial_camera_only.ipynb",
   "private_outputs": true,
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
